#!/usr/bin/env python
# -*- coding: utf-8 -*-

from src.manager.log_manager import LogManager
from numpy import *
import operator

Logger = LogManager.get_logger(__name__)


class KnnManager:
    """
    k近邻算法
    """

    def classify0(self, in_x, data_set, labels, k):
        # data_set是numpy.ndarray类型，其shape属性是tuple类型，分别表示ndarray对象的行数和列数
        data_set_size = data_set.shape[0]
        # numpy.tile(A, reps)方法构建一个数组，这个数组是将参数A重复reps次，返回值为numpy.ndarray类型
        test_vector = tile(in_x, (data_set_size, 1))
        # 两个ndarray类型对象相减，表示两个矩阵相减
        diff_mat = test_vector - data_set
        # ndarray中，每个元素的二次方
        sq_diff_mat = diff_mat ** 2
        # 当axis为0时,是压缩行,即将每一列的元素相加,将矩阵压缩为一行
        # 当axis为1时,是压缩列,即将每一行的元素相加,将矩阵压缩为一列(这里的一列是为了方便理解说的，实际上，在控制台的输出中，仍然是以一行的形式输出的)
        sq_distances = sq_diff_mat.sum(axis=1)
        # 开平方
        distnaces = sq_distances ** 0.5
        # argsort函数返回的是数组值从小到大的索引值
        sorted_dist_indicies = distnaces.argsort()
        class_count = {}
        for i in range(k):
            vote_label = labels[sorted_dist_indicies[i]]
            class_count[vote_label] = class_count.get(vote_label, 0) + 1
        # dict.items() 函数以列表返回可遍历的(键, 值) 元组数组
        # operator模块提供的itemgetter函数用于获取对象的哪些维的数据，参数为一些序号（即需要获取的数据在对象中的序号）
        sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
        return sorted_class_count[0][0]
